{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7yzpcOoGI4jZ"
      },
      "source": [
        "Copyright 2022 Google LLC\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "    https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xzGpqCdpI8BL"
      },
      "source": [
        "## Continuous simulation study\n",
        "\n",
        "This notebook conducts a simulation study using continuous observations, generating the results shown in Table 2. We use a single set of source and target domains where $p(U=1)=0.1$ and $q(U=1)=0.9$. We compare the proposed adaptation algorithms to a series of baselines. This notebook relies on previously executing `colab/synthetic_data_to_file.ipynb`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "y9rtChCqXQbO",
        "scrolled": false
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "import numpy as np\n",
        "import ml_collections as mlc\n",
        "import scipy\n",
        "import os\n",
        "import pandas as pd\n",
        "import sklearn\n",
        "import re\n",
        "import io\n",
        "from sklearn.preprocessing import OneHotEncoder, MinMaxScaler\n",
        "from sklearn.metrics import roc_auc_score\n",
        "from sklearn.metrics import accuracy_score\n",
        "from sklearn.metrics import log_loss\n",
        "import latent_shift_adaptation.methods.algorithms_sknp as algorithms_sknp\n",
        "from latent_shift_adaptation.methods.algorithms_sknp import get_classifier\n",
        "from latent_shift_adaptation.utils import gumbelmax_vae_ci, gumbelmax_vae\n",
        "from latent_shift_adaptation.methods import baseline, erm\n",
        "from latent_shift_adaptation.methods.vae import gumbelmax_vanilla, gumbelmax_graph\n",
        "from latent_shift_adaptation.methods.shift_correction import cov, label, bbse, bbsez\n",
        "from IPython.display import display"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ef5TvfS_I4je"
      },
      "outputs": [],
      "source": [
        "ITERATIONS = 10 # Set to 10 to replicate experiments in paper\n",
        "EPOCHS = 200 # Set to 200 to replicate experiments in paper"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aBO09d6-k9k4"
      },
      "outputs": [],
      "source": [
        "DEFAULT_LOSS = tf.keras.losses.CategoricalCrossentropy(from_logits=True)\n",
        "\n",
        "def mlp(num_classes, width, input_shape, learning_rate,\n",
        "        loss=DEFAULT_LOSS, metrics=None):\n",
        "  \"\"\"Builds and compiles a dense classifier that outputs logits.\n",
        "\n",
        "  Args:\n",
        "    num_classes: Number of output units (one logit per class).\n",
        "    width: Units in the single ReLU hidden layer. A falsy value (0 or None)\n",
        "      skips the hidden layer, leaving a linear model.\n",
        "    input_shape: Shape tuple of one input example, e.g. `(x_dim,)`.\n",
        "    learning_rate: Learning rate of the SGD optimizer (momentum 0.9).\n",
        "    loss: Loss passed to `compile`; defaults to `DEFAULT_LOSS`.\n",
        "    metrics: Optional metrics passed to `compile`; None means no metrics.\n",
        "\n",
        "  Returns:\n",
        "    A compiled `tf.keras.Model`.\n",
        "  \"\"\"\n",
        "  # Fresh list per call: a `[]` default would be shared by every model built.\n",
        "  metrics = [] if metrics is None else metrics\n",
        "  model_input = tf.keras.Input(shape=input_shape)\n",
        "  # Optional hidden layer.\n",
        "  if width:\n",
        "    hidden = tf.keras.layers.Dense(\n",
        "        width, use_bias=True, activation='relu'\n",
        "    )(model_input)\n",
        "  else:\n",
        "    hidden = model_input\n",
        "  # Linear activation: the model emits logits, matching `from_logits=True`\n",
        "  # in the default loss.\n",
        "  model_output = tf.keras.layers.Dense(num_classes,\n",
        "                                       use_bias=True,\n",
        "                                       activation=\"linear\")(hidden)\n",
        "  opt = tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9)\n",
        "  model = tf.keras.models.Model(model_input, model_output)\n",
        "  model.build(input_shape)\n",
        "  model.compile(loss=loss, optimizer=opt, metrics=metrics)\n",
        "\n",
        "  return model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GV9oY0aUOitj"
      },
      "source": [
        "# Data setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IIB09XMpZ8LN"
      },
      "outputs": [],
      "source": [
        "xlabel = 'x'  # 'x' or 'x_scaled'\n",
        "SEED = 0\n",
        "# Seed before any dataset is shuffled, so that the batch drawn during data\n",
        "# setup (and the statistics computed from it) is the same on every run.\n",
        "tf.random.set_seed(SEED)\n",
        "np.random.seed(SEED)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bde2UTzYTYM7"
      },
      "outputs": [],
      "source": [
        "\n",
        "# Convert data to dataframe format\n",
        "def pack_to_df(samples_dict):\n",
        "  \"\"\"Stacks per-partition samples into one dataframe with a `partition` column.\n",
        "\n",
        "  NOTE(review): `get_squeezed_df` is not defined or imported in the visible\n",
        "  part of this notebook; confirm it is in scope before calling this helper.\n",
        "  \"\"\"\n",
        "  frames = {}\n",
        "  for partition_name, samples in samples_dict.items():\n",
        "    frames[partition_name] = get_squeezed_df(samples)\n",
        "  stacked = pd.concat(frames)\n",
        "  # Drop the innermost (per-row) index level so only the dict key remains.\n",
        "  stacked = stacked.reset_index(level=-1, drop=True)\n",
        "  return stacked.rename_axis('partition').reset_index()\n",
        "\n",
        "# Extract dataframe format back to dict format\n",
        "def extract_from_df(samples_df, cols=['u', 'x', 'w', 'c', 'c_logits', 'y', 'y_logits', 'y_one_hot', 'w_binary', 'w_one_hot', 'u_one_hot', 'x_scaled']):\n",
        "  \"\"\"\n",
        "  Extracts dict of numpy arrays from dataframe\n",
        "\n",
        "  A name in `cols` that is itself a column is returned as a 1-D array.\n",
        "  Otherwise every column named `\u003ccol\u003e_\u003cindex\u003e` (case-insensitive, any number\n",
        "  of digits) is gathered, in dataframe column order, into one 2-D array.\n",
        "  Names with no matching column are left out of the result.\n",
        "  \"\"\"\n",
        "  result = {}\n",
        "  for col in cols:\n",
        "    if col in samples_df.columns:\n",
        "      result[col] = samples_df[col].values\n",
        "      continue\n",
        "    # Raw f-string keeps `\\d` a regex escape rather than an invalid string\n",
        "    # escape; `+` also accepts indices \u003e= 10, which `\\d` alone would drop.\n",
        "    pattern = re.compile(rf\"^{re.escape(col)}_\\d+$\", re.IGNORECASE)\n",
        "    matching_columns = [name for name in samples_df.columns if pattern.match(name)]\n",
        "    if matching_columns:\n",
        "      result[col] = samples_df[matching_columns].to_numpy()\n",
        "  return result\n",
        "\n",
        "def extract_from_df_nested(samples_df, cols=['u', 'x', 'w', 'c', 'c_logits', 'y', 'y_logits', 'y_one_hot', 'w_binary', 'w_one_hot', 'u_one_hot', 'x_scaled']):\n",
        "  \"\"\"\n",
        "  Builds a nested dict of numpy arrays keyed by partition, or by domain and\n",
        "  then partition when the dataframe has a `domain` column.\n",
        "  \"\"\"\n",
        "  def _by_partition(frame):\n",
        "    # One extract_from_df() result per distinct value of the partition column.\n",
        "    return {\n",
        "        name: extract_from_df(frame[frame['partition'] == name], cols=cols)\n",
        "        for name in frame['partition'].unique()\n",
        "    }\n",
        "\n",
        "  if 'domain' not in samples_df.keys():\n",
        "    return _by_partition(samples_df)\n",
        "  return {\n",
        "      name: _by_partition(samples_df[samples_df['domain'] == name])\n",
        "      for name in samples_df['domain'].unique()\n",
        "  }"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tA2Cvu3EqFP6"
      },
      "outputs": [],
      "source": [
        "folder_id = \"./tmp_data\"\n",
        "filename_source = \"synthetic_multivariate_num_samples_10000_w_coeff_1_p_u_0_0.9.csv\"\n",
        "filename_target = \"synthetic_multivariate_num_samples_10000_w_coeff_1_p_u_0_0.1.csv\"\n",
        "\n",
        "# Read the pre-generated samples of each domain from disk.\n",
        "data_df_source, data_df_target = (\n",
        "    pd.read_csv(os.path.join(folder_id, filename))\n",
        "    for filename in (filename_source, filename_target)\n",
        ")\n",
        "\n",
        "# Unpack each dataframe into nested dicts of numpy arrays.\n",
        "data_dict_source = extract_from_df_nested(data_df_source)\n",
        "data_dict_target = extract_from_df_nested(data_df_target)\n",
        "data_dict_all = {'source': data_dict_source, 'target': data_dict_target}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "67pBVhpzTfZs"
      },
      "outputs": [],
      "source": [
        "# Convert to TF dataset\n",
        "# Each element is the tuple (x, y_one_hot, c, w_one_hot, u_one_hot); one\n",
        "# dataset is built per key of the corresponding domain dict.\n",
        "ds_source, ds_target = (\n",
        "    {\n",
        "        split: tf.data.Dataset.from_tensor_slices(\n",
        "            tuple(arrays[field] for field in ('x', 'y_one_hot', 'c', 'w_one_hot', 'u_one_hot'))\n",
        "        )\n",
        "        for split, arrays in data_dict_all[domain].items()\n",
        "    }\n",
        "    for domain in ('source', 'target')\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "l5chcXBWfiyW"
      },
      "outputs": [],
      "source": [
        "#@title Data exploration\n",
        "batch = next(iter(ds_source['train']))\n",
        "batch[0].shape, batch[1].shape, batch[2].shape, batch[3].shape, batch[4].shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RaM17X8cgphU"
      },
      "outputs": [],
      "source": [
        "ds_source.keys()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nLDMf281fsvU"
      },
      "outputs": [],
      "source": [
        "batch_size = 128\n",
        "# Turn every split into an endless stream of shuffled mini-batches.\n",
        "# NOTE(review): ds_source/ds_target are rebound in place, so running this cell\n",
        "# twice batches already-batched data; re-run the dataset-construction cell first.\n",
        "# NOTE(review): repeat() comes before shuffle(), so the 1000-element buffer can\n",
        "# mix examples across epoch boundaries - confirm this is intended.\n",
        "for split in ds_source.keys():\n",
        "  ds_source[split] = ds_source[split].repeat().shuffle(1000).batch(batch_size)\n",
        "  ds_target[split] = ds_target[split].repeat().shuffle(1000).batch(batch_size)\n",
        "\n",
        "# Read the width of x, c, w and u off axis 1 of one source training batch.\n",
        "batch = next(iter(ds_source['train']))\n",
        "x_dim = batch[0].shape[1]\n",
        "c_dim = batch[2].shape[1]\n",
        "w_dim = batch[3].shape[1]\n",
        "u_dim = batch[4].shape[1]\n",
        "num_classes = 2\n",
        "test_fract = 0.2\n",
        "val_fract = 0.1\n",
        "\n",
        "# Steps per epoch for each split, derived from the hard-coded example count;\n",
        "# the train split gets whatever remains after test and validation.\n",
        "num_examples = 10_000\n",
        "steps_per_epoch = num_examples // batch_size\n",
        "steps_per_epoch_test = int(steps_per_epoch * test_fract)\n",
        "steps_per_epoch_val = int(steps_per_epoch * val_fract)\n",
        "steps_per_epoch_train = steps_per_epoch - steps_per_epoch_test - steps_per_epoch_val\n",
        "\n",
        "# Positions of x, y, c, w and u inside each dataset tuple.\n",
        "pos = mlc.ConfigDict()\n",
        "pos.x, pos.y, pos.c, pos.w, pos.u = 0, 1, 2, 3, 4"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IUzEiChNRaSq"
      },
      "outputs": [],
      "source": [
        "batch[0].shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jlq9KPcTGSkF"
      },
      "outputs": [],
      "source": [
        "# Per-variable weights computed from the single batch drawn above: the\n",
        "# reciprocal of the variance of x and of the entropy values of y, c and w.\n",
        "# NOTE(review): scipy.stats.entropy treats its argument as an unnormalized\n",
        "# probability vector. Here it receives per-example class indices (argmax) and\n",
        "# the flattened c values rather than class frequencies - confirm this is\n",
        "# intended before relying on h_y, h_c and h_w as label entropies.\n",
        "x = tf.convert_to_tensor(batch[0])\n",
        "var_x = tf.math.reduce_variance(x).numpy()\n",
        "h_y = scipy.stats.entropy(np.argmax(batch[1], axis=1))\n",
        "h_c = scipy.stats.entropy(batch[2].numpy().reshape((-1,)))\n",
        "h_w = scipy.stats.entropy(np.argmax(batch[3], axis=1))\n",
        "weight_x = 1. / var_x\n",
        "weight_y = 1. / h_y\n",
        "weight_c = 1. / h_c\n",
        "weight_w = 1. / h_w"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8DXoiMKKRTlb"
      },
      "outputs": [],
      "source": [
        "x_dim, c_dim, w_dim, u_dim"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jZtiOAFlRmog"
      },
      "outputs": [],
      "source": [
        "weight_x, weight_y, weight_c, weight_w"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YWUnZAIQOokS"
      },
      "source": [
        "# Methods"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Yu-pMja63PPM"
      },
      "source": [
        "Global Params"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qC0pr9atJwRE"
      },
      "outputs": [],
      "source": [
        "tf.random.set_seed(SEED)\n",
        "np.random.seed(SEED)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Xne3e08tRsj0"
      },
      "outputs": [],
      "source": [
        "evals = {  # evaluation functions\n",
        "    \"cross-entropy\": tf.keras.metrics.CategoricalCrossentropy(),\n",
        "    \"accuracy\": tf.keras.metrics.CategoricalAccuracy(),\n",
        "    \"auc\": tf.keras.metrics.AUC(multi_label = False)\n",
        "}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fnc9qqGJ3OkQ"
      },
      "outputs": [],
      "source": [
        "# Multiply the learning rate by 0.1 when the monitored `loss` has not improved\n",
        "# by at least 0.01 for 20 epochs, never going below 1e-7.\n",
        "reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(\n",
        "    monitor='loss', min_delta=0.01, factor=0.1, patience=20,\n",
        "    min_lr=1e-7)\n",
        "\n",
        "callbacks = [reduce_lr]\n",
        "\n",
        "do_calib = True\n",
        "# Metric object handed to every `Method` constructor below.\n",
        "evaluate = tf.keras.metrics.CategoricalCrossentropy()\n",
        "\n",
        "learning_rate = 0.01  #@param {type:\"number\"}\n",
        "width = 100  #@param {type:\"number\"}\n",
        "\n",
        "# Keyword arguments forwarded to `clf.fit(...)` in the method cells below.\n",
        "train_kwargs = {\n",
        "    'epochs': EPOCHS,\n",
        "    'steps_per_epoch':steps_per_epoch_train,\n",
        "    'verbose': True,\n",
        "    'callbacks':callbacks\n",
        "    }\n",
        "\n",
        "# Same settings with the validation step count and no progress output.\n",
        "val_kwargs = {\n",
        "    'epochs': EPOCHS,\n",
        "    'steps_per_epoch':steps_per_epoch_val,\n",
        "    'verbose': False,\n",
        "    'callbacks':callbacks\n",
        "    }\n",
        "\n",
        "# Presumably meant for Keras `evaluate` calls (uses `steps`) - TODO confirm.\n",
        "test_kwargs = {'verbose': False,\n",
        "               'steps': steps_per_epoch_test}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "k_IF-FOtzzNk"
      },
      "outputs": [],
      "source": [
        "result = {}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Mz27IwsDtwRn"
      },
      "source": [
        "## Evaluation metrics\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8VuEW5EBt4JD"
      },
      "outputs": [],
      "source": [
        "# Define sklearn evaluation metrics\n",
        "def soft_accuracy(y_true, y_pred, threshold=0.5, **kwargs):\n",
        "  \"\"\"Accuracy after turning scores into labels with `score \u003e= threshold`.\"\"\"\n",
        "  hard_pred = y_pred \u003e= threshold\n",
        "  return sklearn.metrics.accuracy_score(y_true, hard_pred, **kwargs)\n",
        "\n",
        "def soft_balanced_accuracy(y_true, y_pred, threshold=0.5, **kwargs):\n",
        "  \"\"\"Balanced accuracy after thresholding scores with `score \u003e= threshold`.\"\"\"\n",
        "  hard_pred = y_pred \u003e= threshold\n",
        "  return sklearn.metrics.balanced_accuracy_score(y_true, hard_pred, **kwargs)\n",
        "\n",
        "def log_loss64(y_true, y_pred, **kwargs):\n",
        "  \"\"\"Log loss computed on the scores cast to float64.\"\"\"\n",
        "  scores64 = y_pred.astype(np.float64)\n",
        "  return sklearn.metrics.log_loss(y_true, scores64, **kwargs)\n",
        "\n",
        "# Metric name -\u003e callable(y_true, y_score); the key order fixes the column\n",
        "# order of the result tables.\n",
        "evals_sklearn = {\n",
        "    'cross-entropy': log_loss64,\n",
        "    'accuracy': soft_accuracy,\n",
        "    'balanced_accuracy': soft_balanced_accuracy,\n",
        "    'auc': sklearn.metrics.roc_auc_score,\n",
        "}\n",
        "def evaluate_clf(classifier=None, method_name=None):\n",
        "  \"\"\"Scores a fitted classifier on the source and target test sets.\n",
        "\n",
        "  Args:\n",
        "    classifier: Fitted object whose `predict(x)` returns a 2-D array or tensor\n",
        "      of class scores. Defaults to the notebook-level `clf`.\n",
        "    method_name: Name of the method being evaluated. When it contains 'cbm',\n",
        "      element 1 of the `predict` output is used as the class scores.\n",
        "      Defaults to the notebook-level `method`.\n",
        "\n",
        "  Returns:\n",
        "    Dict `{metric: {'eval_on_source': value, 'eval_on_target': value}}` with\n",
        "    one entry per metric in `evals_sklearn`.\n",
        "  \"\"\"\n",
        "  # Fall back to the globals so existing `evaluate_clf()` calls keep working.\n",
        "  classifier = clf if classifier is None else classifier\n",
        "  method_name = method if method_name is None else method_name\n",
        "\n",
        "  result_dict = {metric: {} for metric in evals_sklearn}\n",
        "  for domain in ('source', 'target'):\n",
        "    test_data = data_dict_all[domain]['test']\n",
        "    y_pred = classifier.predict(test_data['x'])\n",
        "    if 'cbm' in method_name:\n",
        "      # hacky workaround for now\n",
        "      y_pred = y_pred[1]\n",
        "    if tf.is_tensor(y_pred):\n",
        "      y_pred = y_pred.numpy()\n",
        "    # Keep column 1 only (assumed to be the positive-class score).\n",
        "    y_pred = y_pred[:, 1]\n",
        "    for metric, metric_fn in evals_sklearn.items():\n",
        "      result_dict[metric][f'eval_on_{domain}'] = metric_fn(test_data['y'], y_pred)\n",
        "  return result_dict\n",
        "\n",
        "def evaluate_sk(model):\n",
        "  \"\"\"Scores a fitted sklearn-style model on the source and target test sets.\n",
        "\n",
        "  The last column of `model.predict_proba` is used as the score. Returns\n",
        "  `{metric: {'eval_on_source': value, 'eval_on_target': value}}` with one\n",
        "  entry per metric in `evals_sklearn`.\n",
        "  \"\"\"\n",
        "  domains = ('source', 'target')\n",
        "  test_sets = {name: data_dict_all[name]['test'] for name in domains}\n",
        "  scores = {\n",
        "      name: model.predict_proba(test_sets[name]['x'])[:, -1] for name in domains\n",
        "  }\n",
        "  result_dict = {}\n",
        "  for metric, metric_fn in evals_sklearn.items():\n",
        "    result_dict[metric] = {\n",
        "        f'eval_on_{name}': metric_fn(test_sets[name]['y'], scores[name])\n",
        "        for name in domains\n",
        "    }\n",
        "  return result_dict"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0OtB_OMCBKEC"
      },
      "source": [
        "## Sklearn baselines"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mSgB_fhPt78U"
      },
      "outputs": [],
      "source": [
        "def _collect_iterations(result_list):\n",
        "  \"\"\"Stacks per-iteration metric dicts into one dataframe.\n",
        "\n",
        "  Each row is one (iteration, eval_set) pair and each metric is a column.\n",
        "  \"\"\"\n",
        "  return pd.concat([pd.DataFrame(elem).rename_axis('eval_set').reset_index().assign(iteration=i) for i, elem in enumerate(result_list)])\n",
        "\n",
        "result_sk = {}\n",
        "\n",
        "# ERM baselines: fit on one domain's training split, score on both test sets.\n",
        "for method, train_domain in [('erm_source_sk', 'source'), ('erm_target_sk', 'target')]:\n",
        "  result_list = []\n",
        "  train_data = data_dict_all[train_domain]['train']\n",
        "  for seed in range(ITERATIONS):\n",
        "    print(f'Iteration: {seed}')\n",
        "    np.random.seed(seed)\n",
        "    model = get_classifier('mlp')\n",
        "    model.fit(train_data['x'], train_data['y'])\n",
        "    result_list.append(evaluate_sk(model))\n",
        "  result_sk[method] = _collect_iterations(result_list)\n",
        "\n",
        "# Adaptation using u from the source training split. It returns predictions\n",
        "# for the target test inputs only, so the source metrics are recorded as NaN.\n",
        "method = 'known_u_sk'\n",
        "result_list = []\n",
        "for seed in range(ITERATIONS):\n",
        "  print(f'Iteration: {seed}')\n",
        "  np.random.seed(seed)\n",
        "  lsa_pred_probs_target = algorithms_sknp.latent_shift_adaptation(\n",
        "      x_source=data_dict_all['source']['train']['x'],\n",
        "      y_source=data_dict_all['source']['train']['y'],\n",
        "      u_source=data_dict_all['source']['train']['u'],\n",
        "      x_target=data_dict_all['target']['test']['x'],\n",
        "      model_type='mlp')[:, -1]\n",
        "  lsa_result_dict = {}\n",
        "  for metric, metric_fn in evals_sklearn.items():\n",
        "    lsa_result_dict[metric] = {\n",
        "        'eval_on_source': np.nan,\n",
        "        'eval_on_target': metric_fn(data_dict_all['target']['test']['y'], lsa_pred_probs_target),\n",
        "    }\n",
        "  result_list.append(lsa_result_dict)\n",
        "result_sk[method] = _collect_iterations(result_list)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "u2ladlnzuALr"
      },
      "outputs": [],
      "source": [
        "metrics = list(evals_sklearn.keys())\n",
        "result_df_sk = pd.concat(result_sk).reset_index(level=1, drop=True).rename_axis('method').reset_index()\n",
        "\n",
        "# Mean and standard deviation over iterations for every (method, eval_set).\n",
        "grouped_sk = result_df_sk.groupby(['method', 'eval_set'])[metrics]\n",
        "results_mean_sk = grouped_sk.mean()\n",
        "results_sd_sk = grouped_sk.std()\n",
        "\n",
        "# Render each cell as 'mean \\u00B1 sd' with four decimals.\n",
        "format_4dp = '{:.4f}'.format\n",
        "results_formatted_sk = results_mean_sk.applymap(format_4dp) + u' \\u00B1 ' + results_sd_sk.applymap(format_4dp)\n",
        "\n",
        "# Long format (one row per method, eval_set and metric), then one column per eval_set.\n",
        "results_long_sk = results_formatted_sk.reset_index().melt(id_vars=['method', 'eval_set'], value_vars=metrics, var_name='metric', value_name='performance')\n",
        "print(results_long_sk.head())\n",
        "results_pivot_sk = results_long_sk.pivot(index=['method', 'metric'], columns='eval_set')\n",
        "for shown_metric in ('cross-entropy', 'accuracy', 'auc'):\n",
        "  display(results_pivot_sk.query('metric == @shown_metric'))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "12QVCGIROqlx"
      },
      "source": [
        "## ERM - source\n",
        "This trains using empirical risk minimization on the source distribution."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6db3Rv5vV5m_"
      },
      "source": [
        "Choose your input. If it's only X, set `inputs='x'`. If it's both X and C, set `inputs='xc'`, and so on."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XFJydTPpWD6L"
      },
      "outputs": [],
      "source": [
        "inputs = 'x'  #@param {type:\"string\"}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RrZQTf7ZViBV"
      },
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GHFqsZCTgCjz"
      },
      "outputs": [],
      "source": [
        "# Input width is the sum of the widths of the variables named in `inputs`:\n",
        "# each membership test is a bool, so it multiplies the width by 0 or 1.\n",
        "input_shape = x_dim * ('x' in inputs) + c_dim * ('c' in inputs) \\\n",
        "  + u_dim * ('u' in inputs) + w_dim * ('w' in inputs)\n",
        "input_shape = (input_shape, )\n",
        "# Build one model only to print its architecture; the training loops below\n",
        "# build their own model on every iteration.\n",
        "model = mlp(num_classes=num_classes, width=width,\n",
        "            input_shape=input_shape, learning_rate=learning_rate,\n",
        "            metrics=['accuracy'])\n",
        "model.summary()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CgTZvQCxKXjM"
      },
      "outputs": [],
      "source": [
        "method = 'erm_source'\n",
        "result_list = []\n",
        "# One independent run per seed: re-seed, build a fresh MLP, fit it with the\n",
        "# source train/val streams, then score it on both test sets.\n",
        "for seed in range(ITERATIONS):\n",
        "  print(f'Iteration: {seed}')\n",
        "  tf.random.set_seed(seed)\n",
        "  np.random.seed(seed)\n",
        "  model = mlp(num_classes=num_classes, width=width,\n",
        "            input_shape=input_shape, learning_rate=learning_rate,\n",
        "            metrics=['accuracy'])\n",
        "  clf = erm.Method(model, evaluate, inputs=inputs, dtype=tf.float32, pos=pos)\n",
        "  clf.fit(ds_source['train'], ds_source['val'], ds_target['train'],\n",
        "          steps_per_epoch_val, **train_kwargs)\n",
        "  # evaluate_clf() reads the notebook-level `clf` and `method`.\n",
        "  result_list.append(evaluate_clf())\n",
        "# One row per (iteration, eval_set), one column per metric.\n",
        "result[method] = pd.concat([pd.DataFrame(elem).rename_axis('eval_set').reset_index().assign(iteration=i) for i, elem in enumerate(result_list)])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CqVBheJsBctw"
      },
      "source": [
        "## ERM - target\n",
        "This trains using empirical risk minimization on the target distribution, assuming access to both x and y in the target domain."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vS1an4tzBfZc"
      },
      "outputs": [],
      "source": [
        "inputs = 'x'  #@param {type:\"string\"}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gB6cfpOWMCeN"
      },
      "outputs": [],
      "source": [
        "method = 'erm_target'\n",
        "result_list = []\n",
        "# Same loop as the source ERM run, but the model is fit with the target\n",
        "# train/val streams.\n",
        "for seed in range(ITERATIONS):\n",
        "  print(f'Iteration: {seed}')\n",
        "  tf.random.set_seed(seed)\n",
        "  np.random.seed(seed)\n",
        "  # `input_shape` is whatever the most recently executed cell assigned to it.\n",
        "  model = mlp(num_classes=num_classes, width=width,\n",
        "            input_shape=input_shape, learning_rate=learning_rate,\n",
        "            metrics=['accuracy'])\n",
        "  clf = erm.Method(model, evaluate, inputs=inputs, dtype=tf.float32, pos=pos)\n",
        "  clf.fit(ds_target['train'], ds_target['val'], ds_target['train'],\n",
        "          steps_per_epoch_val, **train_kwargs)\n",
        "  # evaluate_clf() reads the notebook-level `clf` and `method`.\n",
        "  result_list.append(evaluate_clf())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IQGhzo7jj1XQ"
      },
      "outputs": [],
      "source": [
        "result[method] = pd.concat([pd.DataFrame(elem).rename_axis('eval_set').reset_index().assign(iteration=i) for i, elem in enumerate(result_list)])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GltTo_QHonGn"
      },
      "source": [
        "## Covariate Shift"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6ZL0qK7wYkbQ"
      },
      "source": [
        "Performs adaptation by weighting instances by $q(x)/p(x)$. We estimate the likelihood ratio using a neural network that discriminates between source and target domains."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "biISuTU3P22L"
      },
      "outputs": [],
      "source": [
        "method = 'covar'\n",
        "result_list = []\n",
        "input_shape = (x_dim, )\n",
        "# One independent run per seed: a label model and a two-class domain\n",
        "# discriminator, both taking x only, are handed to cov.Method.\n",
        "for seed in range(ITERATIONS):\n",
        "  print(f'Iteration: {seed}')\n",
        "  tf.random.set_seed(seed)\n",
        "  np.random.seed(seed)\n",
        "  \n",
        "  model = mlp(num_classes=num_classes, width=width, input_shape=input_shape,\n",
        "              learning_rate=learning_rate,\n",
        "              metrics=['accuracy'])\n",
        "\n",
        "  # NOTE(review): mlp() ends in a linear (logit) layer, while the string loss\n",
        "  # below does not set from_logits - confirm how cov.Method uses this model.\n",
        "  domain_discriminator = mlp(num_classes=2, width=width, input_shape=input_shape,\n",
        "                      learning_rate=learning_rate,\n",
        "                      loss=\"sparse_categorical_crossentropy\",\n",
        "                      metrics=['accuracy'])\n",
        "  \n",
        "  clf = cov.Method(model, domain_discriminator, evaluate,\n",
        "                 dtype=tf.float32, pos=pos)\n",
        "  clf.fit(ds_source['train'], ds_source['val'], ds_target['train'],\n",
        "          steps_per_epoch_val, **train_kwargs)\n",
        "  # evaluate_clf() reads the notebook-level `clf` and `method`.\n",
        "  result_list.append(evaluate_clf())\n",
        "# One row per (iteration, eval_set), one column per metric.\n",
        "result[method] = pd.concat([pd.DataFrame(elem).rename_axis('eval_set').reset_index().assign(iteration=i) for i, elem in enumerate(result_list)])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8tq6xCD1U9Wy"
      },
      "source": [
        "## Label Shift (Oracle Weights)\n",
        "\n",
        "Label shift adjustment with oracle access to $q(y)$, where $q(y)$ is the probability distribution of y in the target domain.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "p0DZ4l2QQcav"
      },
      "outputs": [],
      "source": [
        "method = 'label_oracle'\n",
        "result_list = []\n",
        "input_shape = (x_dim, )\n",
        "# One independent run per seed: a fresh MLP on x is wrapped in label.Method\n",
        "# and fit with the source train/val streams plus the target train stream.\n",
        "for seed in range(ITERATIONS):\n",
        "  print(f'Iteration: {seed}')\n",
        "  tf.random.set_seed(seed)\n",
        "  np.random.seed(seed)\n",
        "\n",
        "  model = mlp(num_classes=num_classes, width=width, input_shape=input_shape,\n",
        "                     learning_rate=learning_rate,\n",
        "                     metrics=['accuracy'])\n",
        "  \n",
        "  clf = label.Method(model, evaluate, num_classes=num_classes,\n",
        "                   dtype=tf.float32, pos=pos)\n",
        "  clf.fit(ds_source['train'], ds_source['val'], ds_target['train'],\n",
        "          steps_per_epoch_val, **train_kwargs)\n",
        "  # evaluate_clf() reads the notebook-level `clf` and `method`.\n",
        "  result_list.append(evaluate_clf())\n",
        "# One row per (iteration, eval_set), one column per metric.\n",
        "result[method] = pd.concat([pd.DataFrame(elem).rename_axis('eval_set').reset_index().assign(iteration=i) for i, elem in enumerate(result_list)])  "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sDQ7y1q_M0pf"
      },
      "source": [
        "## Black box label shift adjustment (BBSE)\n",
        "\n",
        "Label shift adjustment in which the likelihood ratio $q(y)/p(y)$ is estimated using the confusion matrix approach, as in Lipton et al. 2018 (https://arxiv.org/abs/1802.03916)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2q23fz5yWxk3"
      },
      "outputs": [],
      "source": [
        "method = 'bbse'\n",
        "\n",
        "result_list = []\n",
        "input_shape = (x_dim, )\n",
        "# One independent run per seed: two separately initialised MLPs with the same\n",
        "# architecture (`x2z` and `model`) are handed to bbse.Method.\n",
        "for seed in range(ITERATIONS):\n",
        "  print(f'Iteration: {seed}')\n",
        "  tf.random.set_seed(seed)\n",
        "  np.random.seed(seed)\n",
        "  \n",
        "  x2z = mlp(num_classes=num_classes, width=width, input_shape=input_shape,\n",
        "                     learning_rate=learning_rate,\n",
        "                     metrics=['accuracy'])\n",
        "  model = mlp(num_classes=num_classes, width=width, input_shape=input_shape,\n",
        "                     learning_rate=learning_rate,\n",
        "                     metrics=['accuracy'])\n",
        "  \n",
        "  clf = bbse.Method(model, x2z, evaluate, num_classes=num_classes,\n",
        "                   dtype=tf.float32, pos=pos)\n",
        "\n",
        "  clf.fit(ds_source['train'], ds_source['val'], ds_target['train'],\n",
        "          steps_per_epoch_val, **train_kwargs)\n",
        "\n",
        "  # evaluate_clf() reads the notebook-level `clf` and `method`.\n",
        "  result_list.append(evaluate_clf())\n",
        "# One row per (iteration, eval_set), one column per metric.\n",
        "result[method] = pd.concat([pd.DataFrame(elem).rename_axis('eval_set').reset_index().assign(iteration=i) for i, elem in enumerate(result_list)])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FejAurjERdSy"
      },
      "source": [
        "## Latent shift adaptation with known U\n",
        "\n",
        "Adaptation using equation (1) in the paper (https://arxiv.org/abs/2212.11254), assuming access to U in the source domain."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ceg1zAc9Xlij"
      },
      "outputs": [],
      "source": [
        "method = 'known_u'\n",
        "\n",
        "result_list = []\n",
        "# NOTE(review): `input_shape` is assigned but not used in this cell; both\n",
        "# models below spell out their own input shapes.\n",
        "input_shape = (x_dim, )\n",
        "for seed in range(ITERATIONS):\n",
        "  print(f'Iteration: {seed}')\n",
        "  tf.random.set_seed(seed)\n",
        "  np.random.seed(seed)\n",
        "\n",
        "  # Name of the confounder passed to bbsez.Method.\n",
        "  z = 'u'\n",
        "  # Takes x_dim inputs and emits u_dim logits.\n",
        "  x2z_model = mlp(num_classes=u_dim, width=width, input_shape=(x_dim, ),\n",
        "                            learning_rate=learning_rate,\n",
        "                            metrics=['accuracy'])\n",
        "\n",
        "  # Takes x_dim + u_dim inputs and emits num_classes logits.\n",
        "  model = mlp(num_classes=num_classes, width=width, input_shape=(x_dim + u_dim, ),\n",
        "                      learning_rate=learning_rate,\n",
        "                      metrics=['accuracy'])\n",
        "  \n",
        "  clf = bbsez.Method(model, x2z_model, evaluate, num_classes=num_classes,\n",
        "                   dtype=tf.float32, pos=pos, confounder=z)\n",
        "  clf.fit(ds_source['train'], ds_source['val'], ds_target['train'],\n",
        "          steps_per_epoch_val, **train_kwargs)\n",
        "\n",
        "  # evaluate_clf() reads the notebook-level `clf` and `method`.\n",
        "  result_list.append(evaluate_clf())\n",
        "# One row per (iteration, eval_set), one column per metric.\n",
        "result[method] = pd.concat([pd.DataFrame(elem).rename_axis('eval_set').reset_index().assign(iteration=i) for i, elem in enumerate(result_list)])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YZ8YiKiXdUaR"
      },
      "source": [
        "## VAE\n",
        "\n",
        "When U is unknown, we estimate it using a variational autoencoder (VAE). The 'graph-based' VAE enforces the structure of the graph from the paper in its decoder; the 'vanilla' VAE does not enforce any structure."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fP4_Tk-Ax8Xu"
      },
      "outputs": [],
      "source": [
        "# Width of the latent representation used by the VAE cells below.\n",
        "latent_dim = 10  #@param {type:\"number\"}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QGuje5hdCODw"
      },
      "source": [
        "### Graph-Based"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rbNzHyldYaCV"
      },
      "outputs": [],
      "source": [
        "method = 'vae_graph'\n",
        "result_list = []\n",
        "input_shape = (x_dim, )\n",
        "for seed in range(ITERATIONS):\n",
        "  print(f'Iteration: {seed}')\n",
        "  # Seed TensorFlow and NumPy with the iteration index before building models.\n",
        "  tf.random.set_seed(seed)\n",
        "  np.random.seed(seed)\n",
        "\n",
        "  # Encoder: input width x_dim + c_dim + w_dim + num_classes, latent_dim outputs.\n",
        "  encoder = mlp(num_classes=latent_dim, width=width,\n",
        "                input_shape=(x_dim + c_dim + w_dim + num_classes,),\n",
        "                learning_rate=learning_rate,\n",
        "                metrics=['accuracy'])\n",
        "\n",
        "  # x -> latent network and (x, latent) -> y network handed to the method.\n",
        "  model_x2u = mlp(num_classes=latent_dim, width=width, input_shape=(x_dim,),\n",
        "                  learning_rate=learning_rate,\n",
        "                  metrics=['accuracy'])\n",
        "  model_xu2y = mlp(num_classes=num_classes, width=width,\n",
        "                  input_shape=(x_dim + latent_dim,),\n",
        "                  learning_rate=learning_rate,\n",
        "                  metrics=['accuracy'])\n",
        "  vae_opt = tf.keras.optimizers.RMSprop(learning_rate=1e-4)\n",
        "\n",
        "  # Per-variable widths passed to the graph-based VAE.\n",
        "  dims = mlc.ConfigDict()\n",
        "  dims.x = x_dim\n",
        "  dims.y = num_classes\n",
        "  dims.c = c_dim\n",
        "  # Fix: w has width w_dim, the same width the encoder input above accounts\n",
        "  # for; this was previously set to u_dim.\n",
        "  dims.w = w_dim\n",
        "  dims.u = u_dim\n",
        "\n",
        "  clf = gumbelmax_graph.Method(encoder, width, vae_opt,\n",
        "                      model_x2u, model_xu2y,\n",
        "                      dims, latent_dim, None,\n",
        "                      kl_loss_coef=3,\n",
        "                      num_classes=num_classes, evaluate=evaluate,\n",
        "                      dtype=tf.float32, pos=pos)\n",
        "\n",
        "  clf.fit(ds_source['train'], ds_source['val'], ds_target['train'],\n",
        "          steps_per_epoch_val, **train_kwargs)\n",
        "\n",
        "  result_list.append(evaluate_clf())\n",
        "\n",
        "# Stack the per-iteration evaluation frames, tagging each with its iteration.\n",
        "result[method] = pd.concat([\n",
        "    pd.DataFrame(elem).rename_axis('eval_set').reset_index().assign(iteration=i)\n",
        "    for i, elem in enumerate(result_list)\n",
        "])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ivs0BoOyCLKr"
      },
      "source": [
        "### Vanilla"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "diECUFXEaPpL"
      },
      "outputs": [],
      "source": [
        "method = 'vae_vanilla'\n",
        "result_list = []\n",
        "input_shape = (x_dim, )\n",
        "for seed in range(ITERATIONS):\n",
        "  print(f'Iteration: {seed}')\n",
        "  # Seed TensorFlow and NumPy with the iteration index before building models.\n",
        "  tf.random.set_seed(seed)\n",
        "  np.random.seed(seed)\n",
        "\n",
        "  # Encoder and decoder mirror each other: the encoder input width equals the\n",
        "  # decoder output width (x_dim + c_dim + w_dim + num_classes).\n",
        "  encoder = mlp(num_classes=latent_dim, width=width,\n",
        "                input_shape=(x_dim + c_dim + w_dim + num_classes,),\n",
        "                learning_rate=learning_rate,\n",
        "                metrics=['accuracy'])\n",
        "  decoder = mlp(num_classes=x_dim + c_dim + w_dim + num_classes,\n",
        "                width=width, input_shape=(latent_dim,),\n",
        "                learning_rate=learning_rate,\n",
        "                metrics=['accuracy'])\n",
        "\n",
        "  # x -> latent network and (x, latent) -> y network handed to the method.\n",
        "  model_x2u = mlp(num_classes=latent_dim, width=width, input_shape=(x_dim,),\n",
        "                learning_rate=learning_rate,\n",
        "                metrics=['accuracy'])\n",
        "  model_xu2y = mlp(num_classes=num_classes, width=width,\n",
        "                  input_shape=(x_dim + latent_dim,),\n",
        "                learning_rate=learning_rate,\n",
        "                metrics=['accuracy'])\n",
        "  vae_opt = tf.keras.optimizers.Adam(learning_rate=1e-4)\n",
        "\n",
        "  # Build the per-variable widths here instead of relying on a `dims` object\n",
        "  # left over from the graph-based cell, so this cell runs on a fresh kernel.\n",
        "  # The widths match the decoder output above (x, c, w, y).\n",
        "  dims = mlc.ConfigDict()\n",
        "  dims.x = x_dim\n",
        "  dims.y = num_classes\n",
        "  dims.c = c_dim\n",
        "  dims.w = w_dim\n",
        "  dims.u = u_dim\n",
        "\n",
        "  clf = gumbelmax_vanilla.Method(encoder, decoder, vae_opt,\n",
        "                        model_x2u, model_xu2y, kl_loss_coef=3,\n",
        "                        num_classes=num_classes, evaluate=evaluate,\n",
        "                        dtype=tf.float32, pos=pos, dims=dims)\n",
        "  clf.fit(ds_source['train'], ds_source['val'], ds_target['train'],\n",
        "          steps_per_epoch_val, **train_kwargs)\n",
        "\n",
        "  result_list.append(evaluate_clf())\n",
        "\n",
        "# Stack the per-iteration evaluation frames, tagging each with its iteration.\n",
        "result[method] = pd.concat([\n",
        "    pd.DataFrame(elem).rename_axis('eval_set').reset_index().assign(iteration=i)\n",
        "    for i, elem in enumerate(result_list)\n",
        "])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dwfv4K_g4a1D"
      },
      "source": [
        "# Results"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jXiclB6yVn1D"
      },
      "outputs": [],
      "source": [
        "# Stack the per-method frames: the keys of `result` become a 'method' column\n",
        "# and the inner row index is discarded.\n",
        "result_df = (\n",
        "    pd.concat(result)\n",
        "    .reset_index(level=1, drop=True)\n",
        "    .rename_axis('method')\n",
        "    .reset_index()\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0HYsJdEMqS5S"
      },
      "outputs": [],
      "source": [
        "# Metric names are the keys of `evals_sklearn`.\n",
        "metrics = [*evals_sklearn.keys()]\n",
        "metrics"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rjedN_K2WMMo"
      },
      "outputs": [],
      "source": [
        "def _format_4dp(frame):\n",
        "  \"\"\"Returns `frame` with every cell rendered as a four-decimal string.\n",
        "\n",
        "  Uses per-column `Series.map` rather than `DataFrame.applymap`, which is\n",
        "  deprecated since pandas 2.1.\n",
        "  \"\"\"\n",
        "  return frame.apply(lambda column: column.map('{:.4f}'.format))\n",
        "\n",
        "\n",
        "# Aggregate every metric over iterations for each (method, eval_set) pair.\n",
        "grouped_metrics = result_df.groupby(['method', 'eval_set'])[metrics]\n",
        "results_mean = grouped_metrics.mean()\n",
        "results_sd = grouped_metrics.std()\n",
        "# Cells read 'mean <plus-minus> sd'.\n",
        "results_formatted = _format_4dp(results_mean) + u' \\u00B1 ' + _format_4dp(results_sd)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ztxz18SiYkor"
      },
      "outputs": [],
      "source": [
        "# Display the formatted summary table (one row per method and eval_set).\n",
        "results_formatted"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "szzPIAzRI4jr"
      },
      "outputs": [],
      "source": [
        "# Persist the raw per-iteration results next to the notebook.\n",
        "folder_id = './tmp_data'\n",
        "os.makedirs(folder_id, exist_ok=True)\n",
        "csv_path = os.path.join(folder_id, 'method_comparison_table.csv')\n",
        "result_df.to_csv(csv_path, index=False)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wzlB63rm7a1n"
      },
      "outputs": [],
      "source": [
        "# Reshape to long format: one row per (method, eval_set, metric).\n",
        "results_long = results_formatted.reset_index().melt(\n",
        "    id_vars=['method', 'eval_set'],\n",
        "    value_vars=metrics,\n",
        "    var_name='metric',\n",
        "    value_name='performance',\n",
        ")\n",
        "print(results_long.head())\n",
        "# Pivot so each eval_set becomes a column, indexed by (method, metric).\n",
        "results_pivot = results_long.pivot(index=['method', 'metric'], columns='eval_set')\n",
        "# Show one table per metric.\n",
        "for metric_name in ('cross-entropy', 'accuracy', 'auc'):\n",
        "  display(results_pivot.query(f'metric == \"{metric_name}\"'))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lPGH89j97orl"
      },
      "outputs": [],
      "source": [
        "# NOTE(review): `results_pivot_sk` is not created in this part of the notebook;\n",
        "# confirm an earlier cell defines it before running this cell.\n",
        "pivots_to_stack = [results_pivot, results_pivot_sk]\n",
        "results_pivot_concat = pd.concat(pivots_to_stack).sort_index()\n",
        "# Show one combined table per metric.\n",
        "for metric_name in ('cross-entropy', 'accuracy', 'auc'):\n",
        "  display(results_pivot_concat.query(f'metric == \"{metric_name}\"'))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DmoulDI07rfj"
      },
      "outputs": [],
      "source": [
        "method_format_dict = {\n",
        "    'erm_source': 'ERM-SOURCE',\n",
        "    'covar': 'COVAR',\n",
        "    'label_oracle': 'LABEL',\n",
        "    'bbse': 'BBSE', \n",
        "    'vae_graph': 'LSA-WAE (ours)',\n",
        "    'vae_vanilla': 'LSA-WAE-V',\n",
        "    'known_u': 'LSA-ORACLE',\n",
        "    'erm_target': 'ERM-TARGET'\n",
        "    }\n",
        "method_format_df = pd.DataFrame(method_format_dict, index = ['Method']).transpose().rename_axis('method').reset_index()\n",
        "method_order = ['ERM-SOURCE', 'COVAR', 'LABEL', 'BBSE', 'LSA-WAE (ours)', 'LSA-WAE-V', 'LSA-ORACLE', 'ERM-TARGET']"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Fdt7jo408Ajc"
      },
      "outputs": [],
      "source": [
        "# Flatten the pivot, swap the internal method key for its display name, and\n",
        "# index the rows by that display name.\n",
        "results_to_print = (\n",
        "    results_pivot_concat\n",
        "    .droplevel(0, axis=1)\n",
        "    .reset_index()\n",
        "    .merge(method_format_df)\n",
        "    .drop('method', axis=1)\n",
        "    .set_index('Method')\n",
        ")\n",
        "# Reorder rows to the presentation order.\n",
        "result_to_print = results_to_print.loc[method_order]\n",
        "result_to_print"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Jq2fZ5d18LVA"
      },
      "outputs": [],
      "source": [
        "filename = \"simulation_continuous_10_seeds_{metric}.txt\"\n",
        "caption_text = \"{metric}\"\n",
        "# Write one LaTeX table per metric into `folder_id`.\n",
        "for metric in ['auc', 'cross-entropy', 'accuracy']:\n",
        "  # `@metric` in the query string refers to the loop variable.\n",
        "  temp = result_to_print.query('metric == @metric').drop('metric', axis=1)\n",
        "  output_path = os.path.join(folder_id, filename.format(metric=metric))\n",
        "  # The formatted cells built earlier contain a non-ASCII plus-minus sign, so\n",
        "  # write UTF-8 explicitly instead of relying on the platform default encoding.\n",
        "  with open(output_path, \"w\", encoding=\"utf-8\") as handle:\n",
        "    temp.to_latex(\n",
        "        buf=handle,\n",
        "        caption=caption_text.format(metric=metric)\n",
        "    )"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "continuous_simulation_study.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.9"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
